{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "65b3aadc-c540-4cb2-a338-d523d3f22e5b",
   "metadata": {},
   "source": [
    "# Unit test generator using GPT, Claude and Gemini\n",
    "\n",
    "This notebook asks GPT, Claude or Gemini to write unit tests for a piece of Python code, then runs those tests and shows the results (including any errors).\n",
    "\n",
    "**Note:** when I tried `claude-sonnet-4-20250514`, the response was too long and the generated Python was cut off, no matter how large I made the max tokens. This happened with both examples. I've switched to `claude-3-5-sonnet-20240620`, which seems to work better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e610bf56-a46e-4aff-8de1-ab49d62b1ad3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import requests\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import google.generativeai\n",
    "import anthropic\n",
    "from IPython.display import Markdown, display, update_display\n",
    "import gradio as gr\n",
    "import sys\n",
    "import io\n",
    "import traceback\n",
    "import unittest\n",
    "import subprocess\n",
    "import tempfile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f672e1c-87e9-4865-b760-370fa605e614",
   "metadata": {},
   "outputs": [],
   "source": [
    "# keys\n",
    "# Load API keys from a .env file via python-dotenv.\n",
    "# The clients in the next cell are created without explicit keys, so presumably they read these\n",
    "# environment variables themselves; the variables below are only used to report which keys were found.\n",
    "\n",
    "load_dotenv(override=True)\n",
    "\n",
    "openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n",
    "claude_api_key = os.getenv(\"ANTHROPIC_API_KEY\")\n",
    "google_api_key = os.getenv(\"GOOGLE_API_KEY\")\n",
    "\n",
    "# Name the provider in every status line; a bare \"All good\" printed three times\n",
    "# does not say which key was actually found.\n",
    "for provider, key in [(\"OpenAI\", openai_api_key), (\"Claude\", claude_api_key), (\"Google\", google_api_key)]:\n",
    "    if key:\n",
    "        print(f\"{provider} key found\")\n",
    "    else:\n",
    "        print(f\"{provider} key issue\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8aa149ed-9298-4d69-8fe2-8f5de0f667da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialise: API clients and model settings\n",
    "\n",
    "# Clients are built without explicit keys; presumably each library picks its key up\n",
    "# from the environment populated by load_dotenv in the previous cell.\n",
    "openai = OpenAI()\n",
    "claude = anthropic.Anthropic()\n",
    "google.generativeai.configure()\n",
    "\n",
    "# Model identifiers used by the unit_test_* generators.\n",
    "OPENAI_MODEL = \"gpt-4o\"\n",
    "CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"  # previously \"claude-sonnet-4-20250514\", whose output was getting cut off\n",
    "GOOGLE_MODEL = \"gemini-2.0-flash\"\n",
    "\n",
    "# Response length cap; of the three generators only unit_test_claude passes it (as max_tokens).\n",
    "max_tok = 5000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6896636f-923e-4a2c-9d6c-fac07828a201",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"You are an engineer with responsibility for unit testing python code.\"\n",
    "system_message += \"You review base python code and develop unit tests, also in python, which validate each unit of code.\"\n",
    "system_message += \"\"\" The output must be in Python with both the unit tests and comments explaining the purpose of each test.\n",
    "The output should not include any additional text at the start or end including \"```\".  It should be possible to run the code without any updates including an execution statement.\n",
    "Include the base / original python code in the response.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e7b3546-57aa-4c29-bc5d-f211970d04eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def user_prompt_for(python):\n",
    "    user_prompt = \"Review the Python code provided and develop unit tests which can be run in a jupyter lab.\"\n",
    "    user_prompt += \"\"\" The output must be in Python with both the unit tests and comments explaining the purpose of each test.\n",
    "The output should not include any additional text at the start or end including \"```\".  It should be possible to run the code without any updates (include an execution statement).\n",
    "Include the base / original python code in the response.\"\"\"\n",
    "    user_prompt += python\n",
    "    return user_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6190659-f54c-4951-bef4-4960f8e51cc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def messages_for(python):\n",
    "    \"\"\"Return the chat message list (system prompt, then user prompt) for `python`.\"\"\"\n",
    "    system_entry = {\"role\": \"system\", \"content\": system_message}\n",
    "    user_entry = {\"role\": \"user\", \"content\": user_prompt_for(python)}\n",
    "    return [system_entry, user_entry]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b327aa3-3277-44e1-972f-aa7158147ddd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample code under test (kept as a string, not executed here):\n",
    "# a Library class with add/borrow/return methods and a custom BookNotAvailableError.\n",
    "example = \"\"\"class BookNotAvailableError(Exception):\n",
    "    pass\n",
    "\n",
    "class Library:\n",
    "    def __init__(self):\n",
    "        self.inventory = {}  # book title -> quantity\n",
    "        self.borrowed = {}   # user -> list of borrowed book titles\n",
    "\n",
    "    def add_book(self, title, quantity=1):\n",
    "        if quantity <= 0:\n",
    "            raise ValueError(\"Quantity must be positive\")\n",
    "        self.inventory[title] = self.inventory.get(title, 0) + quantity\n",
    "\n",
    "    def borrow_book(self, user, title):\n",
    "        if self.inventory.get(title, 0) < 1:\n",
    "            raise BookNotAvailableError(f\"'{title}' is not available\")\n",
    "        self.inventory[title] -= 1\n",
    "        self.borrowed.setdefault(user, []).append(title)\n",
    "\n",
    "    def return_book(self, user, title):\n",
    "        if user not in self.borrowed or title not in self.borrowed[user]:\n",
    "            raise ValueError(f\"User '{user}' did not borrow '{title}'\")\n",
    "        self.borrowed[user].remove(title)\n",
    "        self.inventory[title] = self.inventory.get(title, 0) + 1\n",
    "\n",
    "    def get_available_books(self):\n",
    "        return {title: qty for title, qty in self.inventory.items() if qty > 0}\n",
    "\n",
    "    def get_borrowed_books(self, user):\n",
    "        return self.borrowed.get(user, [])\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed6e624e-88a5-4f10-8ab5-f071f0ca3041",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second sample code under test (kept as a string, not executed here):\n",
    "# a Calculator class (divide raises ValueError on zero) and an is_prime function.\n",
    "example2 = \"\"\"class Calculator:\n",
    "    def add(self, a, b):\n",
    "        return a + b\n",
    "\n",
    "    def subtract(self, a, b):\n",
    "        return a - b\n",
    "\n",
    "    def divide(self, a, b):\n",
    "        if b == 0:\n",
    "            raise ValueError(\"Cannot divide by zero\")\n",
    "        return a / b\n",
    "\n",
    "    def multiply(self, a, b):\n",
    "        return a * b\n",
    "\n",
    "\n",
    "def is_prime(n):\n",
    "    if n <= 1:\n",
    "        return False\n",
    "    if n <= 3:\n",
    "        return True\n",
    "    if n % 2 == 0 or n % 3 == 0:\n",
    "        return False\n",
    "    i = 5\n",
    "    while i * i <= n:\n",
    "        if n % i == 0 or n % (i + 2) == 0:\n",
    "            return False\n",
    "        i += 6\n",
    "    return True\n",
    "    \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7d2fea8-74c6-4421-8f1e-0e76d5b201b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unit_test_gpt(python):\n",
    "    \"\"\"Stream unit tests for `python` from the OpenAI model.\n",
    "\n",
    "    Yields the accumulated reply text after every streamed chunk.\n",
    "    \"\"\"\n",
    "    completion_stream = openai.chat.completions.create(\n",
    "        model=OPENAI_MODEL,\n",
    "        messages=messages_for(python),\n",
    "        stream=True,\n",
    "    )\n",
    "    text_so_far = \"\"\n",
    "    for chunk in completion_stream:\n",
    "        # delta.content may be None for chunks that carry no text; treat those as empty.\n",
    "        text_so_far += chunk.choices[0].delta.content or \"\"\n",
    "        yield text_so_far"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd84ad8-d55c-4fe0-9eeb-1895c95c4a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unit_test_claude(python):\n",
    "    \"\"\"Stream unit tests for `python` from the Claude model.\n",
    "\n",
    "    Sends the shared system prompt plus the user prompt, capping the reply at\n",
    "    max_tok tokens, and yields the accumulated reply text after each text delta.\n",
    "    \"\"\"\n",
    "    text_so_far = \"\"\n",
    "    with claude.messages.stream(\n",
    "        model=CLAUDE_MODEL,\n",
    "        max_tokens=max_tok,\n",
    "        system=system_message,\n",
    "        messages=[{\"role\": \"user\", \"content\": user_prompt_for(python)}],\n",
    "    ) as stream:\n",
    "        for piece in stream.text_stream:\n",
    "            text_so_far += piece\n",
    "            yield text_so_far"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad86f652-879a-489f-9891-bdc2d97c33b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unit_test_google(python):\n",
    "    \"\"\"Stream unit tests for `python` from the Gemini model.\n",
    "\n",
    "    Yields the accumulated reply after each chunk, with the opening ```python fence\n",
    "    line and any remaining ``` markers removed so the text can be run as-is.\n",
    "    \"\"\"\n",
    "    gemini = google.generativeai.GenerativeModel(\n",
    "        model_name=GOOGLE_MODEL,\n",
    "        system_instruction=system_message,\n",
    "    )\n",
    "    response_stream = gemini.generate_content(contents=user_prompt_for(python), stream=True)\n",
    "    collected = \"\"\n",
    "    for chunk in response_stream:\n",
    "        collected += chunk.text or \"\"\n",
    "        without_fences = collected.replace(\"```python\\n\", \"\").replace(\"```\", \"\")\n",
    "        yield without_fences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "105db6f9-343c-491d-8e44-3a5328b81719",
   "metadata": {},
   "outputs": [],
   "source": [
    "#unit_test_gpt(example)\n",
    "#unit_test_claude(example)\n",
    "#unit_test_google(example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f1ae8f5-16c8-40a0-aa18-63b617df078d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_model(python, model):\n",
    "    if model==\"GPT\":\n",
    "        result = unit_test_gpt(python)\n",
    "    elif model==\"Claude\":\n",
    "        result = unit_test_claude(python)\n",
    "    elif model==\"Google\":\n",
    "        result = unit_test_google(python)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown model\")\n",
    "    for stream_so_far in result:\n",
    "        yield stream_so_far        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1ddb38e-6b0a-4c37-baa4-ace0b7de887a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# with gr.Blocks() as ui:\n",
    "#     with gr.Row():\n",
    "#         python = gr.Textbox(label=\"Python code:\", lines=10, value=example)\n",
    "#         test = gr.Textbox(label=\"Unit tests\", lines=10)\n",
    "#     with gr.Row():\n",
    "#         model = gr.Dropdown([\"GPT\", \"Claude\",\"Google\"], label=\"Select model\", value=\"GPT\")\n",
    "#         generate = gr.Button(\"Generate unit tests\")\n",
    "\n",
    "#     generate.click(select_model, inputs=[python, model], outputs=[test])\n",
    "\n",
    "# ui.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "389ae411-a4f6-44f2-8b26-d46a971687a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def execute_python(code):\n",
    "    # Capture stdout and stderr\n",
    "    output = io.StringIO()\n",
    "    sys_stdout = sys.stdout\n",
    "    sys_stderr = sys.stderr\n",
    "    sys.stdout = output\n",
    "    sys.stderr = output\n",
    "\n",
    "    try:\n",
    "        # Compile the code first\n",
    "        compiled_code = compile(code, '<string>', 'exec')\n",
    "\n",
    "        # Prepare a namespace dict for exec environment\n",
    "        # Include __builtins__ so imports like 'import unittest' work\n",
    "        namespace = {\"__builtins__\": __builtins__}\n",
    "\n",
    "        # Run the user's code, but expect tests will be defined here\n",
    "        exec(compiled_code, namespace)\n",
    "\n",
    "        # Look for unittest.TestCase subclasses in the namespace\n",
    "        loader = unittest.TestLoader()\n",
    "        suite = unittest.TestSuite()\n",
    "\n",
    "        for obj in namespace.values():\n",
    "            if isinstance(obj, type) and issubclass(obj, unittest.TestCase):\n",
    "                tests = loader.loadTestsFromTestCase(obj)\n",
    "                suite.addTests(tests)\n",
    "\n",
    "        # Run the tests\n",
    "        runner = unittest.TextTestRunner(stream=output, verbosity=2)\n",
    "        result = runner.run(suite)\n",
    "\n",
    "    except SystemExit as e:\n",
    "        # Catch sys.exit calls from unittest.main()\n",
    "        output.write(f\"\\nSystemExit called with code {e.code}\\n\")\n",
    "    except Exception as e:\n",
    "        # Catch other errors\n",
    "        output.write(f\"\\nException: {e}\\n\")\n",
    "    finally:\n",
    "        sys.stdout = sys_stdout\n",
    "        sys.stderr = sys_stderr\n",
    "\n",
    "    return output.getvalue()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eca98de3-9e2f-4c23-8bb4-dbb2787a15a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gradio UI: paste code, pick a model, generate unit tests, then run them.\n",
    "with gr.Blocks() as ui:\n",
    "    with gr.Row():\n",
    "        # Code under test, generated tests, and the output of running them.\n",
    "        python = gr.Textbox(value=example2, label=\"Python code:\", lines=10)\n",
    "        test = gr.Textbox(label=\"Unit tests\", lines=10)\n",
    "        test_run = gr.Textbox(label=\"Test results\", lines=10)\n",
    "    with gr.Row():\n",
    "        model = gr.Dropdown(\n",
    "            [\"GPT\", \"Claude\", \"Google\"],\n",
    "            value=\"GPT\",\n",
    "            label=\"Select model\",\n",
    "        )\n",
    "        generate = gr.Button(\"Generate unit tests\")\n",
    "        run = gr.Button(\"Run unit tests\")\n",
    "\n",
    "    # select_model yields the reply incrementally; execute_python returns one report string.\n",
    "    generate.click(fn=select_model, inputs=[python, model], outputs=[test])\n",
    "    run.click(fn=execute_python, inputs=[test], outputs=[test_run])\n",
    "\n",
    "ui.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
